"""Launch the trainer"""
import argparse
import os
import sys
import traceback
from pathlib import Path
from pprint import pprint

import ray

from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline
from trinity.common.config import Config, load_config
from trinity.common.constants import (
    LOG_DIR_ENV_VAR,
    LOG_LEVEL_ENV_VAR,
    LOG_NODE_IP_ENV_VAR,
    PLUGIN_DIRS_ENV_VAR,
)
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.dlc_utils import setup_ray_cluster
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins

# Module-level logger used by the launcher functions in this module.
logger = get_logger(__name__)


def bench(config: Config) -> None:
    """Run the explorer in benchmark mode.

    The explorer is renamed to ``"benchmark"`` on the config, then the actor
    is prepared, the benchmark is run and the actor is shut down. Each step
    blocks on ``ray.get``. Any exception raised while driving the actor is
    logged with its traceback and not re-raised.
    """
    config.explorer.name = "benchmark"
    try:
        actor = Explorer.get_actor(config)
        for stage in ("prepare", "benchmark"):
            ray.get(getattr(actor, stage).remote())
        logger.info("Benchmark finished.")
        ray.get(actor.shutdown.remote())
    except Exception:
        logger.error("Benchmark failed:\n" + traceback.format_exc())


def explore(config: Config) -> None:
    """Drive a standalone explorer actor through its full lifecycle.

    The actor goes through these steps in order, each blocking on ``ray.get``:
    it is prepared, its weights are synchronized, exploration is run, and it
    is shut down. Any exception raised along the way is logged with its
    traceback and not re-raised.
    """
    try:
        actor = Explorer.get_actor(config)
        for stage in ("prepare", "sync_weight", "explore", "shutdown"):
            ray.get(getattr(actor, stage).remote())
    except Exception:
        logger.error("Explorer failed:\n" + traceback.format_exc())


def train(config: Config) -> None:
    """Drive a standalone trainer actor through its full lifecycle.

    The actor goes through these steps in order, each blocking on ``ray.get``:
    it is prepared, its weights are synchronized, training is run, and it is
    shut down. Any exception raised along the way is logged with its traceback
    and not re-raised.
    """
    try:
        actor = Trainer.get_actor(config)
        for stage in ("prepare", "sync_weight", "train", "shutdown"):
            ray.get(getattr(actor, stage).remote())
    except Exception:
        logger.error("Trainer failed:\n" + traceback.format_exc())


def both(config: Config) -> None:
    """Run the explorer and the trainer side by side.

    For the explorer, a step contains `batch_size * sync_interval` rollout
    tasks. For the trainer, it has to consume all experiences generated by the
    explorer in the latest step. The exact number of experiences may vary for
    different algorithms and tasks.

    Both actors are created, awaited until ready, prepared and weight-synced
    together. Then ``explore`` and ``train`` are started, and whichever returns
    first is identified by the name it returns:

    - If the trainer finished, the explorer gets 5 s more.
    - If the explorer finished, the trainer gets
      ``config.synchronizer.sync_timeout`` seconds, e.g. to save a checkpoint.

    Finally both actors are asked to shut down, waiting at most
    ``sync_timeout`` seconds. Failures are logged with a traceback and not
    re-raised.
    """
    try:
        explorer = Explorer.get_actor(config)
        trainer = Trainer.get_actor(config)
        # Explorer first, then trainer: every stage below issues calls in this order.
        actors = (explorer, trainer)
        ray.get([actor.__ray_ready__.remote() for actor in actors])
        for stage in ("prepare", "sync_weight"):
            ray.get([getattr(actor, stage).remote() for actor in actors])

        done, pending = ray.wait(
            [explorer.explore.remote(), trainer.train.remote()],
            num_returns=1,
        )

        finished_name = ray.get(done[0])
        if finished_name == config.trainer.name:
            logger.info(
                "===========================================================\n"
                "> Launcher detected that the `Trainer` process has finished.\n"
                "> Stopping the explorer process immediately.\n"
                "==========================================================="
            )
            ray.wait(pending, timeout=5)
        elif finished_name == config.explorer.name:
            logger.info(
                "===============================================================\n"
                "> Launcher detected that the `Explorer` process has finished.\n"
                "> `Trainer` process may need to save the model checkpoint.\n"
                f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n"
                "> You can force stop the `Trainer` process by pressing Ctrl+C.\n"
                "==============================================================="
            )
            ray.wait(pending, timeout=config.synchronizer.sync_timeout)

        ray.wait(
            [actor.shutdown.remote() for actor in actors],
            timeout=config.synchronizer.sync_timeout,
            num_returns=2,
        )
    except Exception:
        logger.error(f"Explorer or Trainer failed:\n{traceback.format_exc()}")


def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
    """Load a config, connect to Ray and launch the configured mode.

    Args:
        config_path: Path to the config file, read with ``load_config``.
        dlc: When True, the Ray cluster is started with ``setup_ray_cluster``
            and stopped with ``stop_ray_cluster`` afterwards (Aliyun PAI DLC).
            When False, Ray must already be running and this process attaches
            to it with ``ray.init``.
        plugin_dir: Optional plugin directory. If given, it is exported via
            ``PLUGIN_DIRS_ENV_VAR`` before plugins are loaded.

    Raises:
        RuntimeError: If ``dlc`` is False and Ray is not running.
    """
    if plugin_dir:
        os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir
    load_plugins()
    config = load_config(config_path)
    config.check_and_update()
    pprint(config)

    # Environment forwarded to the Ray runtime (as `runtime_env` env vars in
    # the non-DLC path), so Ray-side code sees the same plugin and logging settings.
    envs = {
        PLUGIN_DIRS_ENV_VAR: os.environ.get(PLUGIN_DIRS_ENV_VAR, ""),
        LOG_DIR_ENV_VAR: config.log.save_dir,
        LOG_LEVEL_ENV_VAR: config.log.level,
        LOG_NODE_IP_ENV_VAR: "1" if config.log.group_by_node else "0",
    }
    if dlc:
        setup_ray_cluster(namespace=config.ray_namespace, envs=envs)
    else:
        from trinity.utils.dlc_utils import is_running

        # `is_running` is a predicate and must be *called*. A bare function
        # object is always truthy, so testing it directly would never detect
        # a missing Ray cluster.
        if not is_running():
            raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
        ray.init(
            namespace=config.ray_namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs}
        )

    # try to run task pipeline for raw data
    check_and_run_task_pipeline(config)

    try:
        if config.mode == "explore":
            explore(config)
        elif config.mode == "train":
            train(config)
        elif config.mode == "both":
            both(config)
        elif config.mode == "bench":
            bench(config)
        else:
            # Previously a silent no-op; make an unexpected mode visible.
            logger.error(f"Unsupported mode `{config.mode}`; nothing to run.")
    finally:
        if config.monitor.enable_ray_timeline:
            # The timeline is only a diagnostic artifact, so its export is
            # best-effort. A failure here must not mask the run's own error
            # or skip the DLC cluster cleanup below.
            try:
                timeline_file = os.path.join(config.monitor.cache_dir, "timeline.json")
                logger.info(f"Exporting Ray timeline to {timeline_file}...")
                ray.timeline(filename=timeline_file)
                logger.info("Done. You can open the timeline file in `chrome://tracing`")
            except Exception:
                logger.error(f"Failed to export Ray timeline:\n{traceback.format_exc()}")

        if dlc:
            from trinity.utils.dlc_utils import stop_ray_cluster

            stop_ray_cluster(namespace=config.ray_namespace)


def studio(port: int = 8501):
    """Launch Trinity-Studio, the Streamlit app at ``manager/config_manager.py``.

    ``sys.argv`` is rewritten to emulate
    ``streamlit run <config_manager.py> --server.port <port> --server.fileWatcherType none``.
    Then ``sys.exit`` is called with the result of Streamlit's CLI ``main()``.
    """
    from streamlit.web import cli as stcli

    package_root = Path(__file__).resolve().parent.parent
    app_script = str(package_root / "manager" / "config_manager.py")

    streamlit_args = ["run", app_script]
    streamlit_args += ["--server.port", str(port)]
    streamlit_args += ["--server.fileWatcherType", "none"]
    sys.argv = ["streamlit", *streamlit_args]
    sys.exit(stcli.main())


def main() -> None:
    """Parse the command line and dispatch to the chosen sub-command.

    Sub-commands:
        run: launch an RFT process from ``--config`` (plus ``--plugin-dir`` / ``--dlc``).
        studio: start Trinity-Studio on ``--port``.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="command", required=True)

    # `run`: launch an RFT process described by a config file.
    run_cmd = commands.add_parser("run", help="Run RFT process.")
    run_cmd.add_argument("--config", type=str, required=True, help="Path to the config file.")
    run_cmd.add_argument(
        "--plugin-dir",
        type=str,
        default=None,
        help="Path to the directory containing plugin modules.",
    )
    run_cmd.add_argument("--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC.")

    # `studio`: start the web UI.
    studio_cmd = commands.add_parser("studio", help="Run studio.")
    studio_cmd.add_argument("--port", type=int, default=8501, help="The port for Trinity-Studio.")

    parsed = parser.parse_args()
    if parsed.command == "studio":
        studio(parsed.port)
    elif parsed.command == "run":
        # TODO: support parse all args from command line
        run(parsed.config, parsed.dlc, parsed.plugin_dir)


# Allow executing this module directly as a script, e.g. `python <this file> run --config <path>`.
if __name__ == "__main__":
    main()
